import warnings

import numpy as np
import torch
import wandb
from scipy.special import softmax
from torch import nn
from torch.nn import functional as F

from metrics.calculate_ece_metrics import (
    calculate_ECE_metrics,
)
from metrics.metrics import (
    compute_auc,
)
from metrics.plots import (
    reliability_plot,
    bin_strength_plot,
    roc_no_decision,
)
from models.cnn_3d import Cnn3D
from models.resnet import resnet50, resnet18
from models.wide_resnet import wide_resnet_cifar


def _resolve_checkpoint_suffix(checkpoint_file):
    """Return the checkpoint tag ("best_ece", "best_acc" or "best_auc") found in the name.

    The tags are tried in that order and the first one contained in
    ``checkpoint_file`` wins.

    Raises:
        ValueError: if the name contains none of the known tags.
    """
    for tag in ("best_ece", "best_acc", "best_auc"):
        if tag in checkpoint_file:
            return tag
    raise ValueError(
        "Cannot identify checkpoint '{}': its name must contain one of "
        "'best_ece', 'best_acc' or 'best_auc'.".format(checkpoint_file)
    )


def evaluate(settings, test_loader, checkpoint_file):
    """Evaluate a saved checkpoint on the test set, log the metrics and save plots.

    The network is rebuilt with ``setup_network``, wrapped in
    ``nn.DataParallel`` and restored from the ``"net_state_dict"`` entry of the
    checkpoint. Logits, labels and predictions are collected over the whole
    loader; confidences are the softmax of the logits, divided by
    ``settings.temperature`` first when ``settings.use_temperature_scaling == 1``.
    ECE metrics, accuracy and AUC are written to ``wandb.run.summary`` and the
    ROC / reliability / bin-strength plots are written under
    ``settings.plots_dir``.

    Args:
        settings: configuration object. This function reads ``device``,
            ``num_classes``, ``dataset``, ``use_temperature_scaling``,
            ``temperature`` (only with temperature scaling), ``plots_dir``,
            ``model_name`` and ``seed``, and sets ``settings.suffix``.
        test_loader: iterable of batches exposing ``.dataset``. Batches are
            ``(dwi, t2ax, targets)`` when ``settings.dataset == "prostate_mri"``
            and ``(data, targets)`` otherwise.
        checkpoint_file: path of the checkpoint; its name must contain
            "best_ece", "best_acc" or "best_auc".

    Returns:
        Tuple ``(test_eq_mass_ece, test_accuracy, test_auc)``; accuracy and AUC
        are percentages.

    Raises:
        ValueError: if the checkpoint tag cannot be identified, or if the
            loader yields no samples.
    """
    print(" ---> Starting the test.")

    # Identify which checkpoint is being tested before doing any expensive
    # work, so an unrecognised file name fails fast with a clear message.
    suffix = _resolve_checkpoint_suffix(checkpoint_file)
    settings.suffix = suffix

    net = setup_network(settings)
    net = nn.DataParallel(net)
    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device this evaluation runs on.
    checkpoint_dict = torch.load(
        checkpoint_file, map_location=torch.device(settings.device)
    )
    net.load_state_dict(checkpoint_dict["net_state_dict"])
    net.to(settings.device)
    net.eval()

    num_samples = len(test_loader.dataset)
    labels_np = np.zeros(num_samples)
    predictions_np = np.zeros(num_samples)
    confidences_np = np.zeros((num_samples, settings.num_classes))
    logits_np = np.zeros((num_samples, settings.num_classes))

    # Initial values handed to calculate_ECE_metrics, which returns the
    # computed metrics.
    test_eq_width_ece = 0
    test_eq_mass_ece = 0
    test_class_wise_ece = 0
    total = 0
    correct = 0

    # Run the test
    with torch.no_grad():
        for test_data in test_loader:
            if settings.dataset == "prostate_mri":
                data_dwi, data_t2ax, test_targets = test_data
                data_dwi = data_dwi.to(settings.device)
                data_t2ax = data_t2ax.to(settings.device)
                test_targets = test_targets.to(settings.device)
                test_outputs = net(data_dwi, data_t2ax)
                test_targets = test_targets.to(torch.int64)
            else:
                data, test_targets = test_data
                if "mnist" in settings.dataset:
                    test_targets = torch.squeeze(test_targets, 1).long()
                data = data.to(settings.device)
                test_targets = test_targets.to(settings.device)
                test_outputs = net(data)

            _, predictions = torch.max(test_outputs, 1)  # Get predictions

            samples_batch = test_targets.size(0)
            correct += int(predictions.eq(test_targets).sum().item())

            # `total` doubles as the running write offset: unlike
            # batch_idx * batch_size it stays correct when batch_size is None
            # or when batches have different sizes.
            rows = slice(total, total + samples_batch)
            logits_np[rows, :] = test_outputs.detach().cpu().numpy()
            labels_np[rows] = test_targets.detach().cpu().numpy()
            predictions_np[rows] = predictions.detach().cpu().numpy()
            if settings.use_temperature_scaling == 0:
                confidences_np[rows, :] = (
                    F.softmax(test_outputs, dim=1).detach().cpu().numpy()
                )
            total += samples_batch  # Count the number of seen samples

    if total == 0:
        raise ValueError("test_loader yielded no samples; nothing to evaluate.")

    # Drop rows that were never filled (e.g. drop_last or a subset sampler) so
    # zero placeholders do not leak into the metrics.
    logits_np = logits_np[:total]
    labels_np = labels_np[:total]
    predictions_np = predictions_np[:total]
    confidences_np = confidences_np[:total]

    test_accuracy = 100.0 * correct / total

    # Rescale logits if test for TS
    if settings.use_temperature_scaling == 1:
        confidences_np = softmax(logits_np / settings.temperature, axis=1)

    # Compute ECE
    test_eq_mass_ece, test_eq_width_ece, test_class_wise_ece = calculate_ECE_metrics(
        confidences_np,
        labels_np,
        test_eq_mass_ece,
        test_eq_width_ece,
        test_class_wise_ece,
    )

    test_auc = compute_auc(confidences_np, labels_np, settings) * 100.0

    # Temperature-scaled results are stored under separate "..._TS_<suffix>" keys.
    ts_tag = "TS_" if settings.use_temperature_scaling == 1 else ""
    summary_values = (
        ("test_eq_mass_ece_", test_eq_mass_ece),
        ("test_eq_width_ece_", test_eq_width_ece),
        ("test_class_wise_ece_", test_class_wise_ece),
        ("test_accuracy_", test_accuracy),
        ("test_auc_", test_auc),
    )
    for key_prefix, value in summary_values:
        wandb.run.summary[key_prefix + ts_tag + suffix] = value

    print(
        "   - Test accuracy {:.3f}, Test EM-ECE {:.3f}, Test AUC {:.3f} for {}.\n".format(
            test_accuracy, test_eq_mass_ece, test_auc, suffix
        ),
    )

    # Plot % samples without decision VS 1-error
    roc_no_decision(
        confidences_np,
        labels_np,
        settings,
        settings.plots_dir
        + "/roc_{}_ts{}".format(
            settings.model_name,
            str(settings.use_temperature_scaling),
        )
        + suffix,
    )

    # The remaining plot helpers are given plain Python lists.
    confidence_preds = np.amax(confidences_np, axis=1).tolist()
    predictions_np = predictions_np.tolist()
    labels_np = labels_np.tolist()

    reliability_plot(
        confidence_preds,
        predictions_np,
        labels_np,
        settings.plots_dir
        + "/rel_diag_test_ts{}_{:02d}_".format(
            str(settings.use_temperature_scaling), settings.seed
        )
        + suffix,
    )
    bin_strength_plot(
        confidence_preds,
        predictions_np,
        labels_np,
        settings.plots_dir
        + "/bin_strength_test_ts{}_{:02d}_".format(
            str(settings.use_temperature_scaling), settings.seed
        )
        + suffix,
    )

    return test_eq_mass_ece, test_accuracy, test_auc


def setup_network(settings):
    """Build the network selected by *settings* and move it to ``settings.device``.

    Any dataset whose name contains "cifar" gets the wide ResNet (configured by
    ``settings.depth``, ``settings.widen_factor`` and ``settings.num_classes``),
    regardless of ``settings.net_type``. Otherwise ``settings.net_type`` picks
    one of "resnet50", "resnet18" or "3d_cnn".

    Args:
        settings: configuration object providing ``dataset``, ``device`` and,
            depending on the branch taken, ``net_type``, ``num_classes``,
            ``depth`` and ``widen_factor``.

    Returns:
        The constructed network, already moved to ``settings.device``.

    Raises:
        ValueError: if neither the dataset nor ``net_type`` selects a model.
    """
    if "cifar" in settings.dataset:
        net = wide_resnet_cifar(
            depth=settings.depth,
            width=settings.widen_factor,
            num_classes=settings.num_classes,
        )
    elif settings.net_type == "resnet50":
        net = resnet50(settings.num_classes)
    elif settings.net_type == "resnet18":
        net = resnet18(settings.num_classes)
    elif settings.net_type == "3d_cnn":
        net = Cnn3D(use_norm=0, hot_enc=1, n_in_dwi=3, n_in_st=1)
    else:
        # Previously this branch only warned and then crashed with an
        # UnboundLocalError on `net`; fail with an explicit error instead.
        raise ValueError(
            "Model is not listed: net_type={!r} (dataset={!r}).".format(
                settings.net_type, settings.dataset
            )
        )
    net.to(settings.device)
    return net


def get_new_results(
    settings,
    checkpoint_file,
    test_loader,
    test_em_ece_runs,
    test_acc_runs,
    test_auc_runs,
):
    """Evaluate one checkpoint and append its metrics to the per-run lists.

    Calls ``evaluate(settings, test_loader, checkpoint_file)`` and appends the
    returned equal-mass ECE, accuracy and AUC to ``test_em_ece_runs``,
    ``test_acc_runs`` and ``test_auc_runs`` respectively. The lists are
    modified in place; nothing is returned.
    """
    run_ece, run_acc, run_auc = evaluate(settings, test_loader, checkpoint_file)
    histories_and_values = (
        (test_em_ece_runs, run_ece),
        (test_acc_runs, run_acc),
        (test_auc_runs, run_auc),
    )
    for history, value in histories_and_values:
        history.append(value)
